{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Short example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard-library modules used to configure the session.\n",
    "import logging\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "# Hide warnings and show INFO-level log records; both take effect\n",
    "# before torch is imported below.\n",
    "warnings.simplefilter('ignore')\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "\n",
    "# Add the parent directory to the import path (presumably where\n",
    "# portable_m2d lives -- TODO confirm).\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:<All keys matched successfully>\n",
      "INFO:root:Model input size: [80, 608]\n",
      "INFO:root:Using weights: m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth\n",
      "INFO:root:Feature dimension: 3840\n",
      "INFO:root:Norm stats: -7.1, 4.2\n",
      "INFO:root:Runtime MelSpectrogram(16000, 400, 400, 160, 80, 50, 8000):\n",
      "INFO:root:MelSpectrogram(\n",
      "  Mel filter banks size = (80, 201), trainable_mel=False\n",
      "  (stft): STFT(n_fft=400, Fourier Kernel size=(201, 1, 400), iSTFT=False, trainable=False)\n",
      ")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " using 150 parameters, while dropped 250 out of 400 parameters from m2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth\n",
      " (dropped: ['mask_token', 'decoder_pos_embed', 'decoder_embed.weight', 'decoder_embed.bias', 'decoder_blocks.0.norm1.weight'] ...)\n",
      "<All keys matched successfully>\n"
     ]
    }
   ],
   "source": [
    "from portable_m2d import PortableM2D\n",
    "\n",
    "# Checkpoint file handed to PortableM2D as weight_file.\n",
    "weight = 'm2d_vit_base-80x608p16x16-220930-mr7/checkpoint-300.pth'\n",
    "\n",
    "# Instantiating the model emits the log lines and messages shown in\n",
    "# this cell's output.\n",
    "model = PortableM2D(weight_file=weight)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 63, 3840])\n"
     ]
    }
   ],
   "source": [
    "# One random waveform of 16000 * 10 samples in a batch of size 1\n",
    "# (10 s if the model expects 16 kHz audio -- TODO confirm).\n",
    "num_samples = 16000 * 10\n",
    "wav = torch.rand(1, num_samples)\n",
    "\n",
    "# Encode with M2D; gradient tracking is switched off for this call.\n",
    "with torch.no_grad():\n",
    "    embeddings = model(wav)\n",
    "\n",
    "# The output embeddings have a shape of [Batch, Frame, Dimension].\n",
    "print(embeddings.shape)  # --> torch.Size([1, 63, 3840])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ar",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
